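# This code is essentially the same as Ben's original version.
# The intended difference is changing the risk factor (delta) from 60 to 30.
# NOTE(review): the computeEconomicZone call near the end of this file still
# passes 60.0 — confirm which value is intended.
# Other changes: code for saving the H-representation (Hrep) matrices, and comments.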

# imports
using LightGraphs, IntervalArithmetic, SpatialIndexing, Plots, PyCall
import LazySets, Polyhedra

"""
    create_tree(xd, x1d, x2d)

Build a 2-D R-tree that indexes every grid cell listed in `xd`.

Entry `(i, j)` of `xd` selects the interval `x1d[i]` on the first state axis and
`x2d[j]` on the second. The corresponding rectangle is inserted into the tree with
the cell's position in `xd` (1-based) as its stored value, so spatial queries return
cell ids.
"""
function create_tree(xd::Vector{NTuple{2, Int64}}, x1d::Vector{Interval{Float64}}, x2d::Vector{Interval{Float64}})
    rtree = RTree{Float64,2}(Int)
    @inbounds for cellidx in eachindex(xd)
        i1, i2 = xd[cellidx]
        lower = (x1d[i1].lo, x2d[i2].lo)
        upper = (x1d[i1].hi, x2d[i2].hi)
        insert!(rtree, SpatialIndexing.Rect(lower, upper), cellidx)
    end
    return rtree
end

"""
    ancestors(G::SimpleDiGraph{Int64}, src)

Return every vertex of `G` from which some vertex in `src` can be reached along
directed edges. The vertices of `src` themselves are included (distance 0).

The search runs `gdistances` from `src` on the reversed graph. `G` is reversed in
place for the duration of the search and is always restored before returning, even
when `gdistances` throws. Because `G` is temporarily mutated, it must not be used
elsewhere while this call is running.

`src` is passed straight to `gdistances`: a single vertex or a collection of
unique vertices.
"""
function ancestors(G::SimpleDiGraph{Int64}, src)
    reverse!(G)
    try
        dists = gdistances(G, src)
        # gdistances marks unreachable vertices with typemax of the distance type.
        return findall(d -> d < typemax(Int64), dists)
    finally
        # Restore the caller's edge orientation on both the normal and the error path.
        reverse!(G)
    end
end

"""
    select_nonleaving_cells(G::SimpleDiGraph{Int64})

Return the vertices of the largest strongly connected component of `G`, together
with every vertex that can reach that component (see `ancestors`).

If several components tie for largest, the first one returned by
`strongly_connected_components` is used. A graph with no vertices yields an empty
vector.
"""
function select_nonleaving_cells(G::SimpleDiGraph{Int64})
    scc = strongly_connected_components(G)
    # No vertices -> no components -> nothing to select (avoids indexing scc[0]).
    isempty(scc) && return Int64[]
    # argmax returns the first index of the maximum, so ties keep the earliest component.
    largest = scc[argmax(map(length, scc))]
    return union(largest, ancestors(G, largest))
end

"""
    construct_face(xd, xd1, xd2)

Convert grid cells into `LazySets.Hyperrectangle` boxes. These are used for plotting
and as the generators of convex hulls.

Cell `(i, j)` becomes the box `[xd1[i].lo, xd1[i].hi] × [xd2[j].lo, xd2[j].hi]`.
The result has one box per entry of `xd`, in the same order.
"""
function construct_face(xd::Vector{NTuple{2, Int64}},xd1::Vector{Interval{Float64}},xd2::Vector{Interval{Float64}})
    # Concrete element type taken from a sample box, so even an empty result is typed.
    BoxT = typeof(LazySets.Hyperrectangle(low=[0.0,0.0],high=[1.,1.]))
    return BoxT[
        LazySets.Hyperrectangle(low=[xd1[c[1]].lo, xd2[c[2]].lo],
                                high=[xd1[c[1]].hi, xd2[c[2]].hi])
        for c in xd
    ]
end

"""
    create_graph(xd, x1d, x2d, ud, w, model)

Build the cell-transition graph of `model` over the grid cells `xd`.

Vertex `k` corresponds to the `k`-th cell of `xd`. For each cell, the interval image
`model(x, ud[1], w)` is turned into a rectangle from its bounds. An edge `k → m` is
added for every other cell `m` whose rectangle intersects that image. Self-loops are
never added, and only the first input cell `ud[1]` is used.
"""
function create_graph(xd::Vector{NTuple{2, Int64}}, x1d::Vector{Interval{Float64}}, x2d::Vector{Interval{Float64}}, ud::Vector{Tuple{Interval{Float64}}}, w, model)
    index = create_tree(xd, x1d, x2d)
    graph = SimpleDiGraph(length(xd))
    uin = ud[1]
    @inbounds for src in eachindex(xd)
        i, j = xd[src]
        img = model((x1d[i], x2d[j]), uin, w)
        box = SpatialIndexing.Rect((img[1].lo, img[2].lo), (img[1].hi, img[2].hi))
        for hit in SpatialIndexing.intersects_with(index, box)
            dst = hit.val
            dst == src || add_edge!(graph, src, dst)
        end
    end
    return graph
end

"""
    economic_cost(x, u)

Economic stage cost (Equation (40)) evaluated on an interval state: `x[1]` plus a
quadratic penalty of weight 10 on how far the temperature interval `x[2]` leaves
the band [348, 352].

`x[2]` must provide `.lo`/`.hi` bounds. The penalty uses the worst-case excursion
over the whole interval. When `x[2]` sticks out on both sides, the larger of the two
excursions is penalised, which keeps the result a conservative bound. `u` is
accepted for interface compatibility but is not used.
"""
function economic_cost(x,u)
    # Worst-case distance of x[2] outside [348, 352]; <= 0 when it stays inside.
    excursion = max(x[2].hi - 352.0, 348.0 - x[2].lo)
    if excursion > 0.0
        return 1*x[1] + 10*excursion^2
    else
        return 1*x[1]
    end
end

"""
    determine_economic_zone(xd, x1d, x2d, ud, w, delta, M)

Return the subset `Ce` of the grid cells `xd` (Algorithm 1, lines 2–5).

A cell is kept when the upper bound of its economic cost is at most `delta`. The cost
is evaluated on the cell's box inflated by the disturbance interval `0.1*M*w`. Cells
keep their order from `xd`. Only `ud[1]` is passed to `economic_cost`.
"""
function determine_economic_zone(xd::Vector{NTuple{2, Int64}}, x1d::Vector{Interval{Float64}}, x2d::Vector{Interval{Float64}}, ud::Vector{Tuple{Interval{Float64}}}, w::Vector{Interval{Float64}}, delta::Float64, M::Float64)
    u = ud[1]
    # NOTE(review): the original author questioned the 0.1*M scaling of w;
    # confirm it against the reference algorithm.
    inflation = 0.1*M*w
    return filter(xd) do cell
        box = [x1d[cell[1]], x2d[cell[2]]]
        economic_cost(box + inflation, u).hi <= delta
    end
end


"""
    computeEconomicZone(d, delta)

Compute, for the CSTR example, the target zone, the economic zone and an
approximation of the economic robust control invariant set.

# Arguments
- `d`: number of pieces each state interval is split into by `mince`
  (the grid has `d^2` cells).
- `delta`: bound on the worst-case economic cost a cell may have to belong to the
  economic zone (the "risk factor" of the calling script).

# Returns
`(Xr, Xe, Xt)`, each a `LazySets.ConvexHullArray` of cell boxes:
- `Xr`: economic-zone cells that, for every disturbance scenario in `W`, lie in the
  largest strongly connected component of the transition graph or can reach it;
- `Xe`: cells of the economic zone;
- `Xt`: all cells of the target zone.

Prints the number of economic-zone cells and the first cell of each state axis.

# Throws
- `ArgumentError` if no cell satisfies the economic cost bound `delta`.
"""
function computeEconomicZone(d, delta)
    # Target zone: the whole concentration range, temperature restricted to [348, 352].
    x1 = @interval(0.0,1.)
    x2 = @interval(348.0,352.0)
    u1 = @interval(285.0,315.0)
    w1 = @interval(-0.1,0.1)   # disturbance added to Cin (nominal value 1)
    w2 = @interval(-2.0,2.0)   # disturbance added to Tin (nominal value 350)

    x1s = mince(x1,d)
    x2s = mince(x2,d)
    u1s = mince(u1,1)          # the input range is kept as a single cell

    # CSTR model parameters
    q = 100.0				# inlet flow rate, L/min
    Tin = 350.0				# inlet temperature, K
    Cin = 1.0				# inlet concentration, kmol/m^3
    V = 100.0				# Volume of reactor, L
    k0 = 7.2E10				# rate constant, 1/min
    EoR = 8750.0			# Activate energy/gas constant, K
    UA = 5.0E4				# heat transfer coefficient, J/min.K
    rho = 1000.0				# density, g/L
    Cp = 0.239				# heat capacity, J/g.K
    DH = -5.0E4				# heat of reaction, J/mol
    T = 0.1					# discretization step size, min

    # One explicit-Euler step (step T) of the CSTR dynamics on interval arguments.
    f1de(x::Tuple{Interval{Float64},Interval{Float64}},u::Tuple{Interval{Float64}},w) = (x[1] + T*(( (q/V)*(Cin + w[1] - x[1]) - k0*exp(-EoR/x[2])*x[1] ) ), x[2] + T*( ( (q/V)*(Tin + w[2] - x[2]) + (-DH/(rho*Cp))*k0*exp(-EoR/x[2])*x[1] + (UA/(V*rho*Cp))*(-x[2]) ) +  (UA/(V*rho*Cp))*u[1] ))

    # Grid cells of the target zone and their convex hull.
    x12d = [(i,j) for i in 1:length(x1s) for j in 1:length(x2s)]
    Xt = LazySets.ConvexHullArray(construct_face(x12d,x1s,x2s))
    u1d = [(u,) for u in u1s]
    w = [w1,w2]

    # Step 1: economic zone within the target zone (Ce is a subset of x12d).
    Ce = determine_economic_zone(x12d,x1s,x2s,u1d,w,delta,10.0)
    println(length(Ce))
    # Fail early with a clear message instead of a BoundsError in the graph search.
    isempty(Ce) && throw(ArgumentError(
        "no cell of the target zone satisfies the economic cost bound delta = $delta"))
    Xe = LazySets.ConvexHullArray(construct_face(Ce,x1s,x2s))

    # Step 2: approximate the robust control invariant set (Algorithm 1, lines 6-12).
    # Following Algorithm 2, the disturbance set is enlarged by e. The scenarios are
    # the nominal point and the four corners of the enlarged disturbance box.
    e = (0.005,0.03)
    W = [(0.0,0.0),(-0.1-e[1],-2.0-e[2]),(0.1+e[1],-2.0-e[2]),(-0.1-e[1],2.0+e[2]),(0.1+e[1],2.0+e[2])]

    # Non-leaving cells (indices into Ce) for each disturbance scenario.
    Xes = [select_nonleaving_cells(create_graph(Ce, x1s, x2s, u1d, wv, f1de)) for wv in W]

    # Keep only the cells that are non-leaving under every scenario.
    Cer = reduce(intersect, Xes)
    Cr = Ce[Cer]

    @show x1s[1]
    @show x2s[1]

    Xr = LazySets.ConvexHullArray(construct_face(Cr,x1s,x2s))

    return Xr,Xe,Xt
end

# Arguments: d = 256 mince pieces per state axis (256^2 grid cells);
# delta = 60.0 is the economic cost bound ("risk factor").
Xr,Xe,Xt = computeEconomicZone(256,60.0)

# H-representation (list of half-space constraints) of each convex hull, after
# converting vertices -> VPolytope -> HPolytope and dropping redundant constraints.
Vr, Ve, Vt = map((Xr, Xe, Xt)) do X
    vpoly = LazySets.VPolytope(LazySets.vertices_list(X))
    hpoly = LazySets.convert(LazySets.HPolytope, vpoly)
    LazySets.constraints_list(LazySets.remove_redundant_constraints(hpoly))
end

# # Save the H-representation of Xr (A rows and b entries) in NumPy format.
# @show Vr
# np = pyimport("numpy")
# hrepA = Vector()
# hrepB = Vector()
# for i in 1:length(Vr)
#     push!(hrepA, Vr[i].a)
#     push!(hrepB, Vr[i].b)
# end
# hrepAnp = np.asarray(hrepA)
# hrepBnp = np.asarray(hrepB)
# np.save("results/hrepA.npy", hrepAnp)
# np.save("results/hrepB.npy", hrepBnp)

# Overlay target zone, economic zone and economic control invariant set.
plot(Xt, label="Target Zone",
     ylims=(345,355), yticks = 345:2:355,
     xlabel="C_A", ylabel="T")
plot!(Xe, label="Economic Zone")
plot!(Xr, label="Economic CIS")

println("done")